{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Third Edition](https://www.manning.com/books/deep-learning-with-python-third-edition). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThe book's contents are available online at [deeplearningwithpython.io](https://deeplearningwithpython.io)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!pip install keras keras-hub --upgrade -q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.environ[\"KERAS_BACKEND\"] = \"jax\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "cellView": "form",
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# @title\n",
    "import os\n",
    "from IPython.core.magic import register_cell_magic\n",
    "\n",
    "@register_cell_magic\n",
    "def backend(line, cell):\n",
    "    \"\"\"Cell magic: run `cell` only if KERAS_BACKEND equals the backend named in `line`.\n",
    "\n",
    "    The required backend is the last whitespace-separated word of `line`.\n",
    "    When the active backend differs, the cell is skipped and a hint is printed.\n",
    "    \"\"\"\n",
    "    required = line.split()[-1]\n",
    "    current = os.environ.get(\"KERAS_BACKEND\", \"\")\n",
    "    if current != required:\n",
    "        # Wrong backend: explain how to switch instead of executing the cell.\n",
    "        print(\n",
    "            f\"This cell requires the {required} backend. To run it, change KERAS_BACKEND to \"\n",
    "            f\"\\\"{required}\\\" at the top of the notebook, restart the runtime, and rerun the notebook.\"\n",
    "        )\n",
    "        return\n",
    "    get_ipython().run_cell(cell)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Text classification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### A brief history of natural language processing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Preparing text data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import regex as re\n",
    "\n",
    "def split_chars(text):\n",
    "    \"\"\"Split `text` into a list of single characters (`.` does not match newlines).\"\"\"\n",
    "    any_char = re.compile(r\".\")\n",
    "    return any_char.findall(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "chars = split_chars(\"The quick brown fox jumped over the lazy dog.\")\n",
    "chars[:12]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def split_words(text):\n",
    "    \"\"\"Split `text` into word tokens and standalone punctuation tokens (. , ! ? ;).\"\"\"\n",
    "    word_or_punctuation = re.compile(r\"[\\w]+|[.,!?;]\")\n",
    "    return word_or_punctuation.findall(text)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "split_words(\"The quick brown fox jumped over the dog.\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Toy vocabulary mapping each known token to an integer index; index 0 is\n",
    "# reserved for unknown tokens.\n",
    "vocabulary = {\n",
    "    \"[UNK]\": 0,\n",
    "    \"the\": 1,\n",
    "    \"quick\": 2,\n",
    "    \"brown\": 3,\n",
    "    \"fox\": 4,\n",
    "    \"jumped\": 5,\n",
    "    \"over\": 6,\n",
    "    \"dog\": 7,\n",
    "    \".\": 8,\n",
    "}\n",
    "words = split_words(\"The quick brown fox jumped over the lazy dog.\")\n",
    "# Tokens missing from the vocabulary fall back to 0 (\"[UNK]\"). The lookup is\n",
    "# case-sensitive and nothing here lowercases the text, so the capitalized\n",
    "# \"The\" and the unseen \"lazy\" both map to 0.\n",
    "indices = [vocabulary.get(word, 0) for word in words]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Character and word tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class CharTokenizer:\n",
    "    \"\"\"Character-level tokenizer: lowercases text and maps each character to an id.\"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary):\n",
    "        # `vocabulary` maps token -> integer id and must contain \"[UNK]\".\n",
    "        self.vocabulary = vocabulary\n",
    "        self.unk_id = vocabulary[\"[UNK]\"]\n",
    "\n",
    "    def standardize(self, inputs):\n",
    "        \"\"\"Lowercase the input string.\"\"\"\n",
    "        return inputs.lower()\n",
    "\n",
    "    def split(self, inputs):\n",
    "        \"\"\"Return every character matched by `.` (all but newlines) as its own token.\"\"\"\n",
    "        return re.compile(r\".\").findall(inputs)\n",
    "\n",
    "    def index(self, tokens):\n",
    "        \"\"\"Look up each token's id, using the \"[UNK]\" id for unknown tokens.\"\"\"\n",
    "        lookup = self.vocabulary.get\n",
    "        return [lookup(token, self.unk_id) for token in tokens]\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        \"\"\"Run standardize -> split -> index on a string and return the ids.\"\"\"\n",
    "        return self.index(self.split(self.standardize(inputs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import collections\n",
    "\n",
    "def compute_char_vocabulary(inputs, max_size):\n",
    "    \"\"\"Build a character vocabulary from an iterable of strings.\n",
    "\n",
    "    Index 0 is \"[UNK]\"; the remaining `max_size - 1` slots go to the most\n",
    "    frequent lowercased characters. Returns a dict mapping token -> index.\n",
    "    \"\"\"\n",
    "    char_counts = collections.Counter()\n",
    "    for text in inputs:\n",
    "        char_counts.update(re.findall(r\".\", text.lower()))\n",
    "    vocabulary = [\"[UNK]\"]\n",
    "    # Slots already used by special tokens are deducted from the budget.\n",
    "    num_chars = max_size - len(vocabulary)\n",
    "    vocabulary.extend(token for token, _ in char_counts.most_common(num_chars))\n",
    "    return {token: i for i, token in enumerate(vocabulary)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class WordTokenizer:\n",
    "    \"\"\"Word-level tokenizer: lowercases text and maps words and punctuation to ids.\"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary):\n",
    "        # `vocabulary` maps token -> integer id and must contain \"[UNK]\".\n",
    "        self.vocabulary = vocabulary\n",
    "        self.unk_id = vocabulary[\"[UNK]\"]\n",
    "\n",
    "    def standardize(self, inputs):\n",
    "        \"\"\"Lowercase the input string.\"\"\"\n",
    "        return inputs.lower()\n",
    "\n",
    "    def split(self, inputs):\n",
    "        \"\"\"Return runs of word characters and single . , ! ? ; marks as tokens.\"\"\"\n",
    "        return re.compile(r\"[\\w]+|[.,!?;]\").findall(inputs)\n",
    "\n",
    "    def index(self, tokens):\n",
    "        \"\"\"Look up each token's id, using the \"[UNK]\" id for unknown tokens.\"\"\"\n",
    "        lookup = self.vocabulary.get\n",
    "        return [lookup(token, self.unk_id) for token in tokens]\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        \"\"\"Run standardize -> split -> index on a string and return the ids.\"\"\"\n",
    "        return self.index(self.split(self.standardize(inputs)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compute_word_vocabulary(inputs, max_size):\n",
    "    \"\"\"Build a word vocabulary from an iterable of strings.\n",
    "\n",
    "    Index 0 is \"[UNK]\"; the remaining `max_size - 1` slots go to the most\n",
    "    frequent lowercased words/punctuation marks. Returns a dict token -> index.\n",
    "    \"\"\"\n",
    "    word_counts = collections.Counter()\n",
    "    for text in inputs:\n",
    "        word_counts.update(re.findall(r\"[\\w]+|[.,!?;]\", text.lower()))\n",
    "    vocabulary = [\"[UNK]\"]\n",
    "    # Slots already used by special tokens are deducted from the budget.\n",
    "    num_words = max_size - len(vocabulary)\n",
    "    vocabulary.extend(token for token, _ in word_counts.most_common(num_words))\n",
    "    return {token: i for i, token in enumerate(vocabulary)}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import keras\n",
    "\n",
    "filename = keras.utils.get_file(\n",
    "    origin=\"https://www.gutenberg.org/files/2701/old/moby10b.txt\",\n",
    ")\n",
    "# Read the book as a list of lines; the `with` block closes the file handle\n",
    "# instead of leaving it open for the garbage collector.\n",
    "with open(filename, \"r\") as book_file:\n",
    "    moby_dick = list(book_file)\n",
    "\n",
    "vocabulary = compute_char_vocabulary(moby_dick, max_size=100)\n",
    "char_tokenizer = CharTokenizer(vocabulary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary length:\", len(vocabulary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary start:\", list(vocabulary.keys())[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary end:\", list(vocabulary.keys())[-10:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Line length:\", len(char_tokenizer(\n",
    "   \"Call me Ishmael. Some years ago--never mind how long precisely.\"\n",
    ")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "vocabulary = compute_word_vocabulary(moby_dick, max_size=2_000)\n",
    "word_tokenizer = WordTokenizer(vocabulary)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary length:\", len(vocabulary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary start:\", list(vocabulary.keys())[:5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary end:\", list(vocabulary.keys())[-5:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Line length:\", len(word_tokenizer(\n",
    "   \"Call me Ishmael. Some years ago--never mind how long precisely.\"\n",
    ")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Subword tokenization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "data = [\n",
    "    \"the quick brown fox\",\n",
    "    \"the slow brown fox\",\n",
    "    \"the quick brown foxhound\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def count_and_split_words(data):\n",
    "    \"\"\"Count word occurrences in `data`, keyed by the word split into characters.\n",
    "\n",
    "    Each line is lowercased and split into words/punctuation; a word such as\n",
    "    \"fox\" is stored under the space-separated key \"f o x\".\n",
    "    \"\"\"\n",
    "    word_counts = collections.Counter()\n",
    "    for text in data:\n",
    "        words = re.findall(r\"[\\w]+|[.,!?;]\", text.lower())\n",
    "        word_counts.update(\" \".join(re.findall(r\".\", word)) for word in words)\n",
    "    return dict(word_counts)\n",
    "\n",
    "counts = count_and_split_words(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "counts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def count_pairs(counts):\n",
    "    \"\"\"Count adjacent symbol pairs across all words, weighted by word frequency.\n",
    "\n",
    "    `counts` maps space-separated symbol strings to frequencies; the result\n",
    "    maps (left, right) symbol tuples to their total frequency.\n",
    "    \"\"\"\n",
    "    pair_counts = collections.Counter()\n",
    "    for word, freq in counts.items():\n",
    "        symbols = word.split()\n",
    "        for left, right in zip(symbols, symbols[1:]):\n",
    "            pair_counts[(left, right)] += freq\n",
    "    return pair_counts\n",
    "\n",
    "def merge_pair(counts, first, second):\n",
    "    \"\"\"Merge every adjacent `first second` symbol pair into one symbol.\n",
    "\n",
    "    `counts` maps space-separated symbol strings to frequencies. Returns a new\n",
    "    dict with the same frequencies whose keys have the pair fused, e.g.\n",
    "    merging (\"t\", \"h\") turns \"t h e\" into \"th e\".\n",
    "    \"\"\"\n",
    "    # Raw f-string so `\\S` reaches the regex engine instead of being parsed as a\n",
    "    # (invalid) string escape; re.escape so symbols such as \".\" match literally.\n",
    "    # The lookarounds only allow matches on whole, space-delimited symbols.\n",
    "    split = re.compile(rf\"(?<!\\S){re.escape(first)} {re.escape(second)}(?!\\S)\")\n",
    "    merged = f\"{first}{second}\"\n",
    "    # A callable replacement inserts `merged` verbatim, with no backslash processing.\n",
    "    return {split.sub(lambda _: merged, word): count for word, count in counts.items()}\n",
    "\n",
    "# Run 10 merge steps on the toy corpus: each step fuses the most frequent\n",
    "# adjacent symbol pair and prints the resulting word splits.\n",
    "for i in range(10):\n",
    "    pairs = count_pairs(counts)\n",
    "    first, second = max(pairs, key=pairs.get)\n",
    "    counts = merge_pair(counts, first, second)\n",
    "    print(list(counts.keys()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def compute_sub_word_vocabulary(dataset, vocab_size):\n",
    "    \"\"\"Learn a subword vocabulary and merge table from `dataset`.\n",
    "\n",
    "    Returns `(vocab, merges)`: `vocab` maps token -> id (\"[UNK]\" first, then\n",
    "    single characters by frequency, then merged tokens in merge order) and\n",
    "    `merges` maps \"first second\" -> merge rank.\n",
    "    \"\"\"\n",
    "    word_counts = count_and_split_words(dataset)\n",
    "\n",
    "    # Seed the vocabulary with every character, most frequent first.\n",
    "    char_counts = collections.Counter()\n",
    "    for word, freq in word_counts.items():\n",
    "        for char in word.split():\n",
    "            char_counts[char] += freq\n",
    "    tokens = [\"[UNK]\"]\n",
    "    tokens.extend(char for char, _ in char_counts.most_common())\n",
    "    merge_rules = []\n",
    "\n",
    "    # Keep fusing the most frequent pair until the vocabulary is full or no\n",
    "    # adjacent pairs remain.\n",
    "    while len(tokens) < vocab_size:\n",
    "        pair_counts = count_pairs(word_counts)\n",
    "        if not pair_counts:\n",
    "            break\n",
    "        first, second = max(pair_counts, key=pair_counts.get)\n",
    "        word_counts = merge_pair(word_counts, first, second)\n",
    "        tokens.append(first + second)\n",
    "        merge_rules.append(first + \" \" + second)\n",
    "\n",
    "    vocab = {token: index for index, token in enumerate(tokens)}\n",
    "    merges = {rule: rank for rank, rule in enumerate(merge_rules)}\n",
    "    return vocab, merges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "class SubWordTokenizer:\n",
    "    \"\"\"Subword tokenizer that applies learned pair merges to each word.\"\"\"\n",
    "\n",
    "    def __init__(self, vocabulary, merges):\n",
    "        # `vocabulary` maps token -> id and must contain \"[UNK]\";\n",
    "        # `merges` maps \"first second\" -> rank (lower rank is applied first).\n",
    "        self.vocabulary = vocabulary\n",
    "        self.merges = merges\n",
    "        self.unk_id = vocabulary[\"[UNK]\"]\n",
    "\n",
    "    def standardize(self, inputs):\n",
    "        \"\"\"Lowercase the input string.\"\"\"\n",
    "        return inputs.lower()\n",
    "\n",
    "    def bpe_merge(self, word):\n",
    "        \"\"\"Repeatedly apply the best-ranked known merge to a space-separated word.\"\"\"\n",
    "        while True:\n",
    "            # Adjacent symbol pairs, e.g. \"t h e\" -> [\"t h\", \"h e\"].\n",
    "            symbols = word.split()\n",
    "            pairs = [f\"{left} {right}\" for left, right in zip(symbols, symbols[1:])]\n",
    "            if not pairs:\n",
    "                break\n",
    "            # Unknown pairs get a huge rank so any learned merge wins.\n",
    "            best = min(pairs, key=lambda pair: self.merges.get(pair, 1e9))\n",
    "            if best not in self.merges:\n",
    "                break\n",
    "            first, second = best.split()\n",
    "            # Raw f-string keeps `\\S` a regex escape; re.escape makes symbols\n",
    "            # such as \".\" match literally rather than as metacharacters.\n",
    "            split = re.compile(rf\"(?<!\\S){re.escape(first)} {re.escape(second)}(?!\\S)\")\n",
    "            merged = f\"{first}{second}\"\n",
    "            # A callable replacement inserts `merged` verbatim.\n",
    "            word = split.sub(lambda _: merged, word)\n",
    "        return word\n",
    "\n",
    "    def split(self, inputs):\n",
    "        \"\"\"Split text into words, then each word into its merged subword tokens.\"\"\"\n",
    "        tokens = []\n",
    "        for word in re.findall(r\"[\\w]+|[.,!?;]\", inputs):\n",
    "            # Start from single characters, then merge upward.\n",
    "            word = \" \".join(re.findall(r\".\", word))\n",
    "            word = self.bpe_merge(word)\n",
    "            tokens.extend(word.split())\n",
    "        return tokens\n",
    "\n",
    "    def index(self, tokens):\n",
    "        \"\"\"Look up each token's id, using the \"[UNK]\" id for unknown tokens.\"\"\"\n",
    "        return [self.vocabulary.get(t, self.unk_id) for t in tokens]\n",
    "\n",
    "    def __call__(self, inputs):\n",
    "        \"\"\"Run standardize -> split -> index on a string and return the ids.\"\"\"\n",
    "        inputs = self.standardize(inputs)\n",
    "        tokens = self.split(inputs)\n",
    "        indices = self.index(tokens)\n",
    "        return indices"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "vocabulary, merges = compute_sub_word_vocabulary(moby_dick, 2_000)\n",
    "sub_word_tokenizer = SubWordTokenizer(vocabulary, merges)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary length:\", len(vocabulary))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary start:\", list(vocabulary.keys())[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Vocabulary end:\", list(vocabulary.keys())[-7:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "print(\"Line length:\", len(sub_word_tokenizer(\n",
    "   \"Call me Ishmael. Some years ago--never mind how long precisely.\"\n",
    ")))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Sets vs. sequences"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Loading the IMDb classification dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import os, pathlib, shutil, random\n",
    "\n",
    "# Download the IMDb archive and extract it via `keras.utils.get_file`.\n",
    "# NOTE(review): the name `zip_path` suggests an archive file, but the last\n",
    "# statement treats it as the directory that contains \"aclImdb\" - confirm that\n",
    "# `get_file(extract=True)` returns the extraction directory in the installed\n",
    "# Keras version.\n",
    "zip_path = keras.utils.get_file(\n",
    "    origin=\"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\",\n",
    "    fname=\"imdb\",\n",
    "    extract=True,\n",
    ")\n",
    "\n",
    "imdb_extract_dir = pathlib.Path(zip_path) / \"aclImdb\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "for path in imdb_extract_dir.glob(\"*/*\"):\n",
    "    if path.is_dir():\n",
    "        print(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Path.read_text opens, reads, and closes the file, so no handle is left open.\n",
    "print((imdb_extract_dir / \"train\" / \"pos\" / \"4077_10.txt\").read_text())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "train_dir = pathlib.Path(\"imdb_train\")\n",
    "test_dir = pathlib.Path(\"imdb_test\")\n",
    "val_dir = pathlib.Path(\"imdb_val\")\n",
    "\n",
    "# Start from a clean slate so this cell can be re-run: with leftover output\n",
    "# directories the copies below would fail with FileExistsError, or mix files\n",
    "# from an earlier split into the new one.\n",
    "for out_dir in (train_dir, test_dir, val_dir):\n",
    "    shutil.rmtree(out_dir, ignore_errors=True)\n",
    "\n",
    "shutil.copytree(imdb_extract_dir / \"test\", test_dir)\n",
    "\n",
    "val_percentage = 0.2\n",
    "for category in (\"neg\", \"pos\"):\n",
    "    src_dir = imdb_extract_dir / \"train\" / category\n",
    "    # os.listdir returns entries in arbitrary order; sort first so the seeded\n",
    "    # shuffle yields the same train/validation split on every run.\n",
    "    src_files = sorted(os.listdir(src_dir))\n",
    "    random.Random(1337).shuffle(src_files)\n",
    "    num_val_samples = int(len(src_files) * val_percentage)\n",
    "\n",
    "    # The first `val_percentage` of the shuffled files become the validation\n",
    "    # set; the rest stay in the training set.\n",
    "    os.makedirs(val_dir / category)\n",
    "    for file in src_files[:num_val_samples]:\n",
    "        shutil.copy(src_dir / file, val_dir / category / file)\n",
    "    os.makedirs(train_dir / category)\n",
    "    for file in src_files[num_val_samples:]:\n",
    "        shutil.copy(src_dir / file, train_dir / category / file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras.utils import text_dataset_from_directory\n",
    "\n",
    "batch_size = 32\n",
    "train_ds = text_dataset_from_directory(train_dir, batch_size=batch_size)\n",
    "val_ds = text_dataset_from_directory(val_dir, batch_size=batch_size)\n",
    "test_ds = text_dataset_from_directory(test_dir, batch_size=batch_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Set models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Training a bag-of-words model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import layers\n",
    "\n",
    "max_tokens = 20_000\n",
    "text_vectorization = layers.TextVectorization(\n",
    "    max_tokens=max_tokens,\n",
    "    split=\"whitespace\",\n",
    "    output_mode=\"multi_hot\",\n",
    ")\n",
    "train_ds_no_labels = train_ds.map(lambda x, y: x)\n",
    "text_vectorization.adapt(train_ds_no_labels)\n",
    "\n",
    "bag_of_words_train_ds = train_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "bag_of_words_val_ds = val_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "bag_of_words_test_ds = test_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x, y = next(bag_of_words_train_ds.as_numpy_iterator())\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def build_linear_classifier(max_tokens, name):\n",
    "    \"\"\"Build and compile a one-unit sigmoid Dense classifier over `max_tokens` features.\"\"\"\n",
    "    features = keras.Input(shape=(max_tokens,))\n",
    "    probability = layers.Dense(1, activation=\"sigmoid\")(features)\n",
    "    classifier = keras.Model(features, probability, name=name)\n",
    "    classifier.compile(\n",
    "        loss=\"binary_crossentropy\",\n",
    "        optimizer=\"adam\",\n",
    "        metrics=[\"accuracy\"],\n",
    "    )\n",
    "    return classifier\n",
    "\n",
    "model = build_linear_classifier(max_tokens, \"bag_of_words_classifier\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "early_stopping = keras.callbacks.EarlyStopping(\n",
    "    monitor=\"val_loss\",\n",
    "    restore_best_weights=True,\n",
    "    patience=2,\n",
    ")\n",
    "history = model.fit(\n",
    "    bag_of_words_train_ds,\n",
    "    validation_data=bag_of_words_val_ds,\n",
    "    epochs=10,\n",
    "    callbacks=[early_stopping],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Per-epoch accuracy recorded by `model.fit`, plotted against 1-based epoch numbers.\n",
    "accuracy = history.history[\"accuracy\"]\n",
    "val_accuracy = history.history[\"val_accuracy\"]\n",
    "epochs = range(1, len(accuracy) + 1)\n",
    "\n",
    "plt.plot(epochs, accuracy, \"r--\", label=\"Training accuracy\")\n",
    "plt.plot(epochs, val_accuracy, \"b\", label=\"Validation accuracy\")\n",
    "plt.title(\"Training and validation accuracy\")\n",
    "# Label the axes so the figure can be read on its own.\n",
    "plt.xlabel(\"Epochs\")\n",
    "plt.ylabel(\"Accuracy\")\n",
    "plt.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(bag_of_words_test_ds)\n",
    "test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Training a bigram model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "max_tokens = 30_000\n",
    "text_vectorization = layers.TextVectorization(\n",
    "    max_tokens=max_tokens,\n",
    "    split=\"whitespace\",\n",
    "    output_mode=\"multi_hot\",\n",
    "    ngrams=2,\n",
    ")\n",
    "text_vectorization.adapt(train_ds_no_labels)\n",
    "\n",
    "bigram_train_ds = train_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "bigram_val_ds = val_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "bigram_test_ds = test_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x, y = next(bigram_train_ds.as_numpy_iterator())\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "text_vectorization.get_vocabulary()[100:108]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model = build_linear_classifier(max_tokens, \"bigram_classifier\")\n",
    "model.fit(\n",
    "    bigram_train_ds,\n",
    "    validation_data=bigram_val_ds,\n",
    "    epochs=10,\n",
    "    callbacks=[early_stopping],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(bigram_test_ds)\n",
    "test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Sequence models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "max_length = 600\n",
    "max_tokens = 30_000\n",
    "text_vectorization = layers.TextVectorization(\n",
    "    max_tokens=max_tokens,\n",
    "    split=\"whitespace\",\n",
    "    output_mode=\"int\",\n",
    "    output_sequence_length=max_length,\n",
    ")\n",
    "text_vectorization.adapt(train_ds_no_labels)\n",
    "\n",
    "sequence_train_ds = train_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "sequence_val_ds = val_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")\n",
    "sequence_test_ds = test_ds.map(\n",
    "    lambda x, y: (text_vectorization(x), y), num_parallel_calls=8\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x, y = next(sequence_test_ds.as_numpy_iterator())\n",
    "x.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Training a recurrent model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from keras import ops\n",
    "\n",
    "class OneHotEncoding(keras.Layer):\n",
    "    \"\"\"Layer that turns integer token ids into one-hot vectors of size `depth`.\n",
    "\n",
    "    The output has the input's shape plus a trailing axis of length `depth`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, depth, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.depth = depth\n",
    "\n",
    "    def call(self, inputs):\n",
    "        # ops.one_hot appends the `depth` axis directly, which avoids building\n",
    "        # a (depth, depth) identity matrix only to gather rows from it.\n",
    "        return ops.one_hot(ops.cast(inputs, \"int\"), self.depth)\n",
    "\n",
    "one_hot_encoding = OneHotEncoding(max_tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "x, y = next(sequence_train_ds.as_numpy_iterator())\n",
    "one_hot_encoding(x).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "hidden_dim = 64\n",
    "inputs = keras.Input(shape=(max_length,), dtype=\"int32\")\n",
    "x = one_hot_encoding(inputs)\n",
    "x = layers.Bidirectional(layers.LSTM(hidden_dim))(x)\n",
    "x = layers.Dropout(0.5)(x)\n",
    "outputs = layers.Dense(1, activation=\"sigmoid\")(x)\n",
    "model = keras.Model(inputs, outputs, name=\"lstm_with_one_hot\")\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# \u26a0\ufe0fNOTE\u26a0\ufe0f: The following fit call will error on a T4 GPU on the TensorFlow\n",
    "# backend due to a bug in TensorFlow. If the following cell errors out,\n",
    "# do one of the following:\n",
    "# - Skip the following two cells.\n",
    "# - Switch to the Jax or Torch backend and re-run this notebook.\n",
    "# - Change the GPU type in your runtime (requires Colab Pro as of this writing)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    sequence_train_ds,\n",
    "    validation_data=sequence_val_ds,\n",
    "    epochs=10,\n",
    "    callbacks=[early_stopping],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(sequence_test_ds)\n",
    "test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Understanding word embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Using a word embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "hidden_dim = 64\n",
    "inputs = keras.Input(shape=(max_length,), dtype=\"int32\")\n",
    "x = keras.layers.Embedding(\n",
    "    input_dim=max_tokens,\n",
    "    output_dim=hidden_dim,\n",
    "    mask_zero=True,\n",
    ")(inputs)\n",
    "x = keras.layers.Bidirectional(keras.layers.LSTM(hidden_dim))(x)\n",
    "x = keras.layers.Dropout(0.5)(x)\n",
    "outputs = keras.layers.Dense(1, activation=\"sigmoid\")(x)\n",
    "model = keras.Model(inputs, outputs, name=\"lstm_with_embedding\")\n",
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.fit(\n",
    "    sequence_train_ds,\n",
    "    validation_data=sequence_val_ds,\n",
    "    epochs=10,\n",
    "    callbacks=[early_stopping],\n",
    ")\n",
    "test_loss, test_acc = model.evaluate(sequence_test_ds)\n",
    "test_acc"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Pretraining a word embedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "imdb_vocabulary = text_vectorization.get_vocabulary()\n",
    "tokenize_no_padding = keras.layers.TextVectorization(\n",
    "    vocabulary=imdb_vocabulary,\n",
    "    split=\"whitespace\",\n",
    "    output_mode=\"int\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "context_size = 4\n",
    "# Each window is `context_size` tokens on either side of one center token.\n",
    "window_size = 2 * context_size + 1\n",
    "\n",
    "def window_data(token_ids):\n",
    "    \"\"\"Slice `token_ids` into every sliding window of `window_size` ids.\n",
    "\n",
    "    Returns a `tf.data.Dataset` with one window per element.\n",
    "    Assumes `token_ids` is a 1D tensor of token ids - TODO confirm.\n",
    "    \"\"\"\n",
    "    # Number of full windows; 0 when the text is shorter than one window.\n",
    "    num_windows = tf.maximum(tf.size(token_ids) - context_size * 2, 0)\n",
    "    # Row i holds the indices i .. i + window_size - 1.\n",
    "    windows = tf.range(window_size)[None, :]\n",
    "    windows = windows + tf.range(num_windows)[:, None]\n",
    "    windowed_tokens = tf.gather(token_ids, windows)\n",
    "    return tf.data.Dataset.from_tensor_slices(windowed_tokens)\n",
    "\n",
    "def split_label(window):\n",
    "    \"\"\"Split a window into (surrounding context tokens, center token).\"\"\"\n",
    "    left = window[:context_size]\n",
    "    right = window[context_size + 1 :]\n",
    "    bag = tf.concat((left, right), axis=0)\n",
    "    # The label is the center token; index it with `context_size` so it stays\n",
    "    # correct if the context size is changed.\n",
    "    label = window[context_size]\n",
    "    return bag, label\n",
    "\n",
    "dataset = keras.utils.text_dataset_from_directory(\n",
    "    imdb_extract_dir / \"train\", batch_size=None\n",
    ")\n",
    "dataset = dataset.map(lambda x, y: x, num_parallel_calls=8)\n",
    "dataset = dataset.map(tokenize_no_padding, num_parallel_calls=8)\n",
    "dataset = dataset.interleave(window_data, cycle_length=8, num_parallel_calls=8)\n",
    "dataset = dataset.map(split_label, num_parallel_calls=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "hidden_dim = 64\n",
    "# The inputs are integer token ids, so declare the dtype explicitly.\n",
    "inputs = keras.Input(shape=(2 * context_size,), dtype=\"int32\")\n",
    "cbow_embedding = layers.Embedding(\n",
    "    max_tokens,\n",
    "    hidden_dim,\n",
    ")\n",
    "x = cbow_embedding(inputs)\n",
    "# Average the context-token embeddings into a single vector.\n",
    "x = layers.GlobalAveragePooling1D()(x)\n",
    "# Softmax yields one probability distribution over the vocabulary, which is\n",
    "# what the single-label sparse categorical crossentropy loss expects.\n",
    "outputs = layers.Dense(max_tokens, activation=\"softmax\")(x)\n",
    "cbow_model = keras.Model(inputs, outputs)\n",
    "cbow_model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"sparse_categorical_crossentropy\",\n",
    "    metrics=[\"sparse_categorical_accuracy\"],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "cbow_model.summary(line_length=80)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Bind the batched, cached dataset to a new name so that re-running this cell\n",
    "# does not batch an already-batched dataset.\n",
    "cbow_train_ds = dataset.batch(1024).cache()\n",
    "cbow_model.fit(cbow_train_ds, epochs=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Using the pretrained embedding for classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "# Integer token ids in, matching the other LSTM models in this notebook.\n",
    "inputs = keras.Input(shape=(max_length,), dtype=\"int32\")\n",
    "# Same (max_tokens, hidden_dim) shape as `cbow_embedding`, so the pretrained\n",
    "# weights can be copied into this layer.\n",
    "lstm_embedding = layers.Embedding(\n",
    "    input_dim=max_tokens,\n",
    "    output_dim=hidden_dim,\n",
    "    mask_zero=True,\n",
    ")\n",
    "x = lstm_embedding(inputs)\n",
    "x = layers.Bidirectional(layers.LSTM(hidden_dim))(x)\n",
    "x = layers.Dropout(0.5)(x)\n",
    "outputs = layers.Dense(1, activation=\"sigmoid\")(x)\n",
    "model = keras.Model(inputs, outputs, name=\"lstm_with_cbow\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "lstm_embedding.embeddings.assign(cbow_embedding.embeddings)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.compile(\n",
    "    optimizer=\"adam\",\n",
    "    loss=\"binary_crossentropy\",\n",
    "    metrics=[\"accuracy\"],\n",
    ")\n",
    "model.fit(\n",
    "    sequence_train_ds,\n",
    "    validation_data=sequence_val_ds,\n",
    "    epochs=10,\n",
    "    callbacks=[early_stopping],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "test_loss, test_acc = model.evaluate(sequence_test_ds)\n",
    "test_acc"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "chapter14_text-classification",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}